#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on 2021/10/01

@author: Feng
"""

# -*- coding: utf-8 -*-
"""
Spyder Editor

"""

import pickle
import hddm
import pandas as pd
import matplotlib.pyplot as plt
from kabuki.analyze import *
import pylab
import numpy
# The script refers to numpy as `np` throughout; import the alias explicitly
# instead of relying on the star import above to leak it into scope.
import numpy as np
import os

# Project root: the working directory at launch. Each model below works in
# its own sub-folder of this directory.
root_path = os.getcwd()

# Input data set. The path is joined rather than concatenated:
# root_path + './trial_hddm_eye.csv' produces '<cwd>./trial_hddm_eye.csv',
# which only resolves where the OS ignores a trailing dot in a folder name.
gdata = hddm.load_csv(os.path.join(root_path, 'trial_hddm_eye.csv'))

#%% aDDM4a: gaze effect on lottery weighting

# Work inside the model's own folder so that every relative file name below
# (traces.db, the pickled model, the CSV exports) lands there.
aDDM4a_dir = os.path.join(root_path, 'aDDM4a')
os.makedirs(aDDM4a_dir, exist_ok=True)  # os.chdir raises on a missing folder
os.chdir(aDDM4a_dir)

aDDM4a = hddm.HDDMRegressor(
    gdata,
    [
        "v ~ C(Role) + Price + LotteryGazeDurPriceRatio:C(Role) + LotteryGazeDurLotteryRatio:C(Role)",
        "z ~ C(Role)",
        "a ~ C(Role)",
    ],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
aDDM4a.find_starting_values()
aDDM4a.sample(12000, burn=2000, dbname='traces.db', db='pickle')

aDDM4a.print_stats()
aDDM4a.plot_posteriors()
aDDM4a.save('aDDM4a')

# Continue with the model as stored on disk. One reload is enough: the
# original script loaded the same file a second time before gen_stats().
aDDM4a = hddm.load('aDDM4a')

# Posterior traces.
aDDM4a_trace = aDDM4a.get_traces()
aDDM4a_trace.to_csv("./aDDM4a_trace.csv")

# Node table; its 'mean' column is read later for the per-subject summary.
aDDM4a_post = aDDM4a.nodes_db
aDDM4a_post.to_csv("./aDDM4a_post.csv")

# Summary statistics.
stats = aDDM4a.gen_stats()
stats.to_csv("aDDM4a_stats.csv")

# Single-row table holding the model's DIC.
ddm_dic = pd.DataFrame({'dic': [aDDM4a.dic]})
ddm_dic.to_csv("./aDDM4a_dic.csv")

# Per-subject posterior means for aDDM4a. Row (isub - 1) holds subject isub.
post_columns = [
    'aBuy', 'aSell', 'aDelta',
    't',
    'zBuy', 'zSell', 'zDelta',
    'vInterceptBuy', 'vInterceptSell', 'vInterceptDelta',
    'vPrice',
    'vLotteryBuyGazeOnPrice', 'vLotteryBuyGazeOnLottery',
    'vLotterySellGazeOnPrice', 'vLotterySellGazeOnLottery',
    'wLotteryBuyGazeOnPrice', 'wLotteryBuyGazeOnLottery',
    'wLotterySellGazeOnPrice', 'wLotterySellGazeOnLottery',
]

# Table size kept from the original script.
# NOTE(review): assumes subj_idx runs 1..64 -- confirm against the data set.
n_subjects = 64
post_tab = pd.DataFrame(np.zeros((n_subjects, len(post_columns))),
                        columns=post_columns)

# Output column -> prefix of the node name in aDDM4a_post; the subject id is
# appended to the prefix to form the row label.
node_prefixes = [
    ('aBuy', 'a_Intercept_subj.'),
    ('aDelta', 'a_C(Role)[T.2]_subj.'),
    ('t', 't_subj.'),
    ('zBuy', 'z_Intercept_subj.'),
    ('zDelta', 'z_C(Role)[T.2]_subj.'),
    ('vInterceptBuy', 'v_Intercept_subj.'),
    ('vInterceptDelta', 'v_C(Role)[T.2]_subj.'),
    ('vPrice', 'v_Price_subj.'),
    ('vLotteryBuyGazeOnPrice', 'v_LotteryGazeDurPriceRatio:C(Role)[1]_subj.'),
    ('vLotteryBuyGazeOnLottery', 'v_LotteryGazeDurLotteryRatio:C(Role)[1]_subj.'),
    ('vLotterySellGazeOnPrice', 'v_LotteryGazeDurPriceRatio:C(Role)[2]_subj.'),
    ('vLotterySellGazeOnLottery', 'v_LotteryGazeDurLotteryRatio:C(Role)[2]_subj.'),
]

for isub in np.unique(gdata.subj_idx):
    row = isub - 1
    for column, prefix in node_prefixes:
        # .loc writes into post_tab itself. The chained form
        # post_tab.<col>[row] = ... may write to a temporary copy instead.
        post_tab.loc[row, column] = aDDM4a_post.loc[prefix + str(isub), 'mean']

# Second-role values = first-role value + the C(Role)[T.2] contrast.
post_tab['aSell'] = post_tab['aBuy'] + post_tab['aDelta']
post_tab['zSell'] = post_tab['zBuy'] + post_tab['zDelta']
post_tab['vInterceptSell'] = post_tab['vInterceptBuy'] + post_tab['vInterceptDelta']

# w* columns: each gaze coefficient divided by the negated Price coefficient.
neg_price = -post_tab['vPrice']
post_tab['wLotteryBuyGazeOnPrice'] = np.true_divide(post_tab['vLotteryBuyGazeOnPrice'], neg_price)
post_tab['wLotterySellGazeOnPrice'] = np.true_divide(post_tab['vLotterySellGazeOnPrice'], neg_price)
post_tab['wLotteryBuyGazeOnLottery'] = np.true_divide(post_tab['vLotteryBuyGazeOnLottery'], neg_price)
post_tab['wLotterySellGazeOnLottery'] = np.true_divide(post_tab['vLotterySellGazeOnLottery'], neg_price)

post_tab.to_csv("./aDDM4a_post_subject.csv")


# simulate choice and rt
#
# For every trial, evaluate the Wiener first-passage-time density implied by
# the subject's posterior-mean parameters on a fixed RT grid and store the
# grid point with the highest density in gdata['RTSim'].

aDDM4a_dir = os.path.join(root_path, 'aDDM4a')
os.chdir(aDDM4a_dir)
aDDM4a = hddm.load('aDDM4a')

err = 0.0001  # passed to pdf_array as its `err` argument

post = aDDM4a.nodes_db
post_mean = post['mean']

# Candidate RTs; identical for every trial, so built once outside the loop.
# NOTE(review): negative values presumably code lower-boundary responses
# (HDDM's signed-RT convention) -- confirm.
rt = np.arange(-10.000, 10.001, 0.001)

# Pull the trial columns out once; rows are visited in file order.
n_trials = gdata.shape[0]
subj_col = gdata['subj_idx'].values
role_col = gdata['Role'].values
price_col = gdata['Price'].values
gaze_price_col = gdata['LotteryGazeDurPriceRatio'].values
gaze_lottery_col = gdata['LotteryGazeDurLotteryRatio'].values

# Results are collected here and attached to gdata in one assignment; the
# chained form gdata.RTSim[idx] = ... may write to a temporary copy.
rt_sim = np.zeros(n_trials)

for idx in range(n_trials):

    isub = subj_col[idx]
    irole = role_col[idx]
    subj = str(isub)

    # (irole - 1) switches the C(Role)[T.2] contrast on for Role == 2.
    t = post_mean.loc['t_subj.' + subj]
    a = post_mean.loc['a_Intercept_subj.' + subj] + (irole - 1) * post_mean.loc['a_C(Role)[T.2]_subj.' + subj]
    z = post_mean.loc['z_Intercept_subj.' + subj] + (irole - 1) * post_mean.loc['z_C(Role)[T.2]_subj.' + subj]

    v_Intercept = post_mean.loc['v_Intercept_subj.' + subj] + (irole - 1) * post_mean.loc['v_C(Role)[T.2]_subj.' + subj]
    price_term = post_mean.loc['v_Price_subj.' + subj] * price_col[idx]
    v_Buy = (post_mean.loc['v_LotteryGazeDurPriceRatio:C(Role)[1]_subj.' + subj] * gaze_price_col[idx]
             + post_mean.loc['v_LotteryGazeDurLotteryRatio:C(Role)[1]_subj.' + subj] * gaze_lottery_col[idx]
             + price_term)
    v_Sell = (post_mean.loc['v_LotteryGazeDurPriceRatio:C(Role)[2]_subj.' + subj] * gaze_price_col[idx]
              + post_mean.loc['v_LotteryGazeDurLotteryRatio:C(Role)[2]_subj.' + subj] * gaze_lottery_col[idx]
              + price_term)
    # Exactly one of the two weights is 1 when Role is 1 or 2.
    v = v_Intercept + (2 - irole) * v_Buy + (irole - 1) * v_Sell

    # The three literal zeros fill the 3rd, 6th and 8th positional arguments
    # (presumably sv, sz and st -- confirm against hddm.wfpt.pdf_array).
    pdf = hddm.wfpt.pdf_array(rt, v, 0, a, z, 0, t, 0, err)
    rt_pmax = rt[np.argmax(pdf)]

    rt_sim[idx] = rt_pmax

gdata['RTSim'] = rt_sim

gdata.to_csv(os.path.join(aDDM4a_dir, 'aDDM4a_hddm_rt_simulated.csv'))


#%% aDDM2a: no gaze effect on lottery weighting

# aDDM2a - z_buyer != z_seller, v_lottery_buyer = v_lottery_seller
# NOTE(review): the v formula below contains Lottery:C(Role), i.e. one Lottery
# slope per role -- confirm that "v_lottery_buyer = v_lottery_seller" is the
# intended description of this model.

# Work inside the model's own folder so that every relative file name below
# (traces.db, the pickled model, the CSV exports) lands there.
aDDM2a_dir = os.path.join(root_path, 'aDDM2a')
os.makedirs(aDDM2a_dir, exist_ok=True)  # os.chdir raises on a missing folder
os.chdir(aDDM2a_dir)

aDDM2a = hddm.HDDMRegressor(
    gdata,
    [
        "v ~ C(Role) + Price + Lottery:C(Role)",
        "z ~ C(Role)",
        "a ~ C(Role)",
    ],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
aDDM2a.find_starting_values()
aDDM2a.sample(12000, burn=2000, dbname='traces.db', db='pickle')

aDDM2a.print_stats()
aDDM2a.plot_posteriors()
aDDM2a.save('aDDM2a')

# Posterior traces.
aDDM2a_trace = aDDM2a.get_traces()
aDDM2a_trace.to_csv("./aDDM2a_trace.csv")

# Node table.
aDDM2a_post = aDDM2a.nodes_db
aDDM2a_post.to_csv("./aDDM2a_post.csv")

# Summary statistics and DIC are taken from the model as stored on disk.
aDDM2a = hddm.load('aDDM2a')
stats = aDDM2a.gen_stats()
stats.to_csv("aDDM2a_stats.csv")

# Single-row table holding the model's DIC.
ddm_dic = pd.DataFrame({'dic': [aDDM2a.dic]})
ddm_dic.to_csv("./aDDM2a_dic.csv")


#%% first gaze effect

# aDDM4aa - first gaze effect on z; gaze effect on lottery weighting

# Work inside the model's own folder so that every relative file name below
# (traces.db, the pickled model, the CSV exports) lands there.
aDDM4aa_dir = os.path.join(root_path, 'aDDM4aa')
os.makedirs(aDDM4aa_dir, exist_ok=True)  # os.chdir raises on a missing folder
os.chdir(aDDM4aa_dir)

aDDM4aa = hddm.HDDMRegressor(
    gdata,
    [
        "v ~ C(Role) + Price + LotteryGazeDurPriceRatio:C(Role) + LotteryGazeDurLotteryRatio:C(Role)",
        "z ~ C(Role) + C(FirstGazeLottery)",
        "a ~ C(Role)",
    ],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
aDDM4aa.find_starting_values()
aDDM4aa.sample(12000, burn=2000, dbname='traces.db', db='pickle')

aDDM4aa.print_stats()
aDDM4aa.plot_posteriors()
aDDM4aa.save('aDDM4aa')

# Posterior traces.
aDDM4aa_trace = aDDM4aa.get_traces()
aDDM4aa_trace.to_csv("./aDDM4aa_trace.csv")

# Node table.
aDDM4aa_post = aDDM4aa.nodes_db
aDDM4aa_post.to_csv("./aDDM4aa_post.csv")

# Summary statistics and DIC are taken from the model as stored on disk.
aDDM4aa = hddm.load('aDDM4aa')
stats = aDDM4aa.gen_stats()
stats.to_csv("aDDM4aa_stats.csv")

# Single-row table holding the model's DIC.
ddm_dic = pd.DataFrame({'dic': [aDDM4aa.dic]})
ddm_dic.to_csv("./aDDM4aa_dic.csv")


# aDDM4aaa - first gaze effect on z by role; gaze effect on lottery weighting

# Work inside the model's own folder so that every relative file name below
# (traces.db, the pickled model, the CSV exports) lands there.
aDDM4aaa_dir = os.path.join(root_path, 'aDDM4aaa')
os.makedirs(aDDM4aaa_dir, exist_ok=True)  # os.chdir raises on a missing folder
os.chdir(aDDM4aaa_dir)

aDDM4aaa = hddm.HDDMRegressor(
    gdata,
    [
        "v ~ C(Role) + Price + LotteryGazeDurPriceRatio:C(Role) + LotteryGazeDurLotteryRatio:C(Role)",
        "z ~ C(Role) + C(FirstGazeLottery):C(Role)",
        "a ~ C(Role)",
    ],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
aDDM4aaa.find_starting_values()
aDDM4aaa.sample(12000, burn=2000, dbname='traces.db', db='pickle')

aDDM4aaa.print_stats()
aDDM4aaa.plot_posteriors()
aDDM4aaa.save('aDDM4aaa')

# Posterior traces.
aDDM4aaa_trace = aDDM4aaa.get_traces()
aDDM4aaa_trace.to_csv("./aDDM4aaa_trace.csv")

# Node table.
aDDM4aaa_post = aDDM4aaa.nodes_db
aDDM4aaa_post.to_csv("./aDDM4aaa_post.csv")

# Summary statistics and DIC are taken from the model as stored on disk.
aDDM4aaa = hddm.load('aDDM4aaa')
stats = aDDM4aaa.gen_stats()
stats.to_csv("aDDM4aaa_stats.csv")

# Single-row table holding the model's DIC.
ddm_dic = pd.DataFrame({'dic': [aDDM4aaa.dic]})
ddm_dic.to_csv("./aDDM4aaa_dic.csv")



# aDDM4ac - first gaze effect on a; gaze effect on lottery weighting

# Work inside the model's own folder so that every relative file name below
# (traces.db, the pickled model, the CSV exports) lands there.
aDDM4ac_dir = os.path.join(root_path, 'aDDM4ac')
os.makedirs(aDDM4ac_dir, exist_ok=True)  # os.chdir raises on a missing folder
os.chdir(aDDM4ac_dir)

aDDM4ac = hddm.HDDMRegressor(
    gdata,
    [
        "v ~ C(Role) + Price + LotteryGazeDurPriceRatio:C(Role) + LotteryGazeDurLotteryRatio:C(Role)",
        "z ~ C(Role)",
        "a ~ C(Role) + C(FirstGazeLottery)",
    ],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
aDDM4ac.find_starting_values()
aDDM4ac.sample(12000, burn=2000, dbname='traces.db', db='pickle')

aDDM4ac.print_stats()
aDDM4ac.plot_posteriors()
aDDM4ac.save('aDDM4ac')

# Posterior traces.
aDDM4ac_trace = aDDM4ac.get_traces()
aDDM4ac_trace.to_csv("./aDDM4ac_trace.csv")

# Node table.
aDDM4ac_post = aDDM4ac.nodes_db
aDDM4ac_post.to_csv("./aDDM4ac_post.csv")

# Summary statistics and DIC are taken from the model as stored on disk.
aDDM4ac = hddm.load('aDDM4ac')
stats = aDDM4ac.gen_stats()
stats.to_csv("aDDM4ac_stats.csv")

# Single-row table holding the model's DIC.
ddm_dic = pd.DataFrame({'dic': [aDDM4ac.dic]})
ddm_dic.to_csv("./aDDM4ac_dic.csv")


# aDDM4aca - first gaze effect on a by role; gaze effect on lottery weighting

# Work inside the model's own folder so that every relative file name below
# (traces.db, the pickled model, the CSV exports) lands there.
aDDM4aca_dir = os.path.join(root_path, 'aDDM4aca')
os.makedirs(aDDM4aca_dir, exist_ok=True)  # os.chdir raises on a missing folder
os.chdir(aDDM4aca_dir)

aDDM4aca = hddm.HDDMRegressor(
    gdata,
    [
        "v ~ C(Role) + Price + LotteryGazeDurPriceRatio:C(Role) + LotteryGazeDurLotteryRatio:C(Role)",
        "z ~ C(Role)",
        "a ~ C(Role) + C(FirstGazeLottery):C(Role)",
    ],
    group_only_regressors=False,
    include={'a', 'v', 'z'},
)
aDDM4aca.find_starting_values()
aDDM4aca.sample(12000, burn=2000, dbname='traces.db', db='pickle')

aDDM4aca.print_stats()
aDDM4aca.plot_posteriors()
aDDM4aca.save('aDDM4aca')

# Posterior traces.
aDDM4aca_trace = aDDM4aca.get_traces()
aDDM4aca_trace.to_csv("./aDDM4aca_trace.csv")

# Node table.
aDDM4aca_post = aDDM4aca.nodes_db
aDDM4aca_post.to_csv("./aDDM4aca_post.csv")

# Summary statistics and DIC are taken from the model as stored on disk.
aDDM4aca = hddm.load('aDDM4aca')
stats = aDDM4aca.gen_stats()
stats.to_csv("aDDM4aca_stats.csv")

# Single-row table holding the model's DIC.
ddm_dic = pd.DataFrame({'dic': [aDDM4aca.dic]})
ddm_dic.to_csv("./aDDM4aca_dic.csv")


#%%